# Setup ----
# NOTE(review): this only removes objects from the global environment; it does
# not unload packages or reset options. Restarting R gives a truly clean state.
rm(list = ls())

# Bootstrap pacman. It is only ever called as `pacman::`, so it does not need to
# be attached; requireNamespace() checks availability without attaching it and
# without require()'s "there is no package called" warning.
if (!requireNamespace("pacman", quietly = TRUE)) {
  install.packages("pacman")
}

# Load (installing if needed) every package the script uses.
pacman::p_load(rstudioapi,
               reticulate,
               keras,
               dplyr,
               rio,
               ds4psy,
               stringr,
               conflicted,
               quanteda)


# Set working directory (disabled; run the script from the data directory)
# setwd(dirname(getActiveDocumentContext()$path))

# Resolve package conflicts: `import()` is exported by both rio and reticulate,
# and conflicted would otherwise refuse an unqualified call.
conflict_prefer("import", "reticulate")

# Data preparation ----
# Read the speeches and keep each row name as an explicit speech identifier, so
# it survives the sentence split below.
df <- rio::import("speeches19_rep.Rdata")
df$speech_id <- rownames(df)

# Build a corpus from the speech text, split it into one document per
# sentence, and flatten it back into a data frame. The other speech columns
# presumably come along as document variables; verify against the data.
df2 <- df %>%
  corpus(text_field = "speechContent") %>%
  corpus_reshape(to = "sentences") %>%
  quanteda::convert(to = "data.frame")

rm(df)

# Texts to classify: one per sentence, shortened to at most 1000 characters.
# str_trunc marks shortened texts with an ellipsis. The cap presumably keeps
# inputs under the tokenizer's 512-token limit.
text2classify <- str_trunc(df2$text, 1000)


# Python / model setup ----
# Python interpreter to use. RETICULATE_PYTHON, when set and non-empty, takes
# precedence over the hard-coded virtualenv path, so the script can run on
# other machines without editing it.
python_path <- Sys.getenv("RETICULATE_PYTHON")
if (!nzchar(python_path)) {
  python_path <- "C:/Users/jdiener/OneDrive/Documents/.virtualenvs/r-reticulate/Scripts/python.exe"
}
reticulate::use_python(python_path)

# Fail fast when a required Python module is missing. The bare
# py_module_available() calls used previously returned a value that nothing
# checked.
missing_modules <- Filter(
  Negate(reticulate::py_module_available),
  c("transformers", "torch")
)
if (length(missing_modules) > 0) {
  stop("Missing Python module(s): ", paste(missing_modules, collapse = ", "),
       " (interpreter: ", python_path, ")", call. = FALSE)
}

# Set up transformers
transformers <- reticulate::import("transformers")

# Set up model and tokenizer.
# NOTE(review): trust_remote_code = TRUE allows code shipped with the model
# repository to be executed; consider pinning a `revision` if it is not fully
# trusted.
MODEL <- transformers$AutoModelForSequenceClassification$from_pretrained(
  "chkla/parlbert-topic-german",
  trust_remote_code = TRUE,
  max_length = 512L
)
# NOTE(review): the tokenizer is loaded from a different checkpoint than the
# model; this presumably relies on ParlBERT sharing bert-base-german-cased's
# vocabulary. Confirm.
TOKENIZER <- transformers$AutoTokenizer$from_pretrained(
  "bert-base-german-cased",
  use_fast = TRUE,
  model_max_length = 512L
)

# Set up classifier: a PyTorch text-classification pipeline combining the
# model and tokenizer above.
classify <- transformers$pipeline(
  task = "text-classification",
  model = MODEL,
  tokenizer = TOKENIZER,
  framework = "pt"
)

# Classify data ----
# Sentences are sent to the pipeline in chunks of `batch_len`, so progress can
# be reported and partial results checkpointed. Chunks are derived from the
# actual number of sentences. The previous hard-coded upper bound (777300)
# indexed past the end of `text2classify` (producing NA texts) for any smaller
# corpus. The pipeline's own `batch_size` (below) is the model forward-pass
# batch, which is separate from these R-level chunks.
batch_len <- 100L
n_sentences <- length(text2classify)
batch <- unname(split(seq_len(n_sentences),
                      ceiling(seq_len(n_sentences) / batch_len)))

# One slot per sentence; slots stay NULL until their chunk is classified.
classified_sentences <- vector(mode = "list", length = n_sentences)

for (i in seq_along(batch)) {
  idx <- batch[[i]]
  classified_sentences[idx] <- classify(text2classify[idx], batch_size = 80L)

  # Checkpoint every 10th chunk, so a crash does not lose all progress.
  if (i %% 10L == 0L) {
    saveRDS(classified_sentences, file = "backup.Rds")
  }

  message(paste0("Batch ", i, " of ", length(batch), ": ",
                 round(i / length(batch) * 100, 0), "%"))
}


# Process data ----
# Append the predicted label and its score to the sentence data frame, one value
# per sentence. vapply() keeps rows aligned with df2: a sentence that was
# never classified (a NULL slot) gets NA. The previous do.call(rbind, ...) silently
# dropped such slots, which could misalign or recycle rows in the cbind. A
# result without a length-1 label/score makes vapply() error loudly.
stopifnot(length(classified_sentences) == nrow(df2))

df_classified <- df2 %>%
  mutate(
    label_text = vapply(
      classified_sentences,
      function(res) if (is.null(res)) NA_character_ else as.character(res[["label"]]),
      character(1)
    ),
    score = vapply(
      classified_sentences,
      function(res) if (is.null(res)) NA_real_ else as.numeric(res[["score"]]),
      numeric(1)
    )
  )

# Save data
saveRDS(df_classified, file = "speeches_19_sent_class_parlBert.Rds")


# Comparison with hand coding ----
# Restore the objects saved in confusion_matrix.RData into the current
# environment. They are used to compare the hand-coded sample with the
# classifier's labels.
load(file = "confusion_matrix.RData")